"""Curvature analysis of strongly-regular graphs."""

import argparse
import itertools
import logging
import sys
import warnings

import gudhi as gd
import gudhi.wasserstein
import networkx as nx
import numpy as np
import pandas as pd
from python_log_indenter import IndentedLoggerAdapter
from scipy.stats import wasserstein_distance

from curvature import forman_curvature
from curvature import ollivier_ricci_curvature
from curvature import resistance_curvature
from utils import propagate_edge_attribute_to_nodes
from utils import propagate_node_attribute_to_edges


def calculate_persistent_homology(G, k=3):
    """Compute persistence diagrams of the clique complex of ``G``.

    Every vertex and edge is inserted into a GUDHI simplex tree with its
    ``"curvature"`` attribute as filtration value. The filtration is then
    made non-decreasing and expanded to cliques of dimension at most ``k``.

    Parameters
    ----------
    G : networkx graph
        Graph whose nodes and edges all carry a ``"curvature"`` attribute.
    k : int, optional
        Maximum dimension used for the clique expansion (default 3).

    Returns
    -------
    list of list of tuple
        ``diagrams[d]`` holds the ``(birth, death)`` pairs of dimension ``d``
        for ``d = 0, ..., k``. Pairs of any other dimension are discarded.
    """
    tree = gd.SimplexTree()

    for vertex, attrs in G.nodes(data=True):
        tree.insert([vertex], filtration=attrs["curvature"])

    for source, target, attrs in G.edges(data=True):
        tree.insert([source, target], filtration=attrs["curvature"])

    # Raise simplices whose value lies below one of their faces so that the
    # filtration is valid before higher-order cliques are added.
    tree.make_filtration_non_decreasing()
    tree.expansion(k)

    # Bucket pairs by dimension in a single pass; relative order is kept.
    diagrams = [[] for _ in range(k + 1)]
    for dimension, (birth, death) in tree.persistence():
        if 0 <= dimension <= k:
            diagrams[dimension].append((birth, death))

    return diagrams


def laplacian_eigenvalues(G):
    """Return the eigenvalues of the Laplacian matrix of ``G``.

    ``FutureWarning``s emitted while networkx computes the spectrum are
    silenced; all other warnings pass through unchanged.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=FutureWarning)
        spectrum = nx.laplacian_spectrum(G)
    return spectrum
    


def prob_rw(G, node, node_to_index):
    """Probability measure based on random-walk probabilities.

    Returns the normalised sum of the distributions of a simple random walk
    started at ``node`` after 0, 1 and 2 steps, i.e. ``(e + e P + e P^2) / 3``
    for the indicator row vector ``e`` of ``node`` and the transition matrix
    ``P = D^{-1} A``.

    Parameters
    ----------
    G : networkx graph
        Graph without isolated nodes. Edge ``"weight"`` attributes, if
        present, are used as adjacency entries (networkx default).
    node : node of ``G``
        Starting node of the walk.
    node_to_index : dict
        Maps nodes to positions in the returned vector. Assumed to agree with
        the order of ``G.nodes`` used for the adjacency matrix -- TODO confirm
        at call sites.

    Returns
    -------
    numpy.ndarray
        Vector of length ``len(G.nodes)`` summing to one.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``G`` contains a node with zero degree; the transition matrix is
        undefined in that case.
    """
    # Rows/columns follow the order of G.nodes (networkx default nodelist).
    A = nx.to_numpy_array(G)
    degrees = A.sum(axis=1)
    if np.any(degrees == 0):
        raise np.linalg.LinAlgError(
            "Random-walk measure is undefined for graphs with isolated nodes."
        )

    # Row-stochastic transition matrix P[i, j] = A[i, j] / deg(i). Dividing
    # row-wise replaces the explicit inverse of the diagonal degree matrix.
    P = A / degrees[:, np.newaxis]

    x = np.zeros(len(G.nodes))
    x[node_to_index[node]] = 1.0

    # A walk's distribution evolves as p_{t+1} = p_t P (row vectors), i.e.
    # P.T @ p_t for column vectors. On regular undirected graphs P is
    # symmetric, so this agrees with the previous P @ x formulation there.
    one_step = P.T @ x
    two_steps = P.T @ one_step
    values = x + one_step + two_steps

    return values / values.sum()


def prob_two_hop(
    G, node, node_to_index, *, alpha=0.5, w_one_hop=0.25, w_two_hop=0.05
):
    """Probability measure based on two-hop neighbourhoods.

    ``node`` receives weight ``alpha``. Every direct neighbour receives
    ``(1 - alpha) * w_one_hop``, and every node at hop distance exactly two
    receives ``(1 - alpha) * w_two_hop``. All other nodes receive zero. The
    vector is then normalised to sum to one, so the weights only need to be
    meaningful relative to each other.

    Parameters
    ----------
    G : networkx graph
        Graph providing ``neighbors`` and ``nodes``.
    node : node of ``G``
        Centre of the measure. A self-loop on ``node`` is ignored, so the
        centre always keeps weight ``alpha``.
    node_to_index : dict
        Maps each node of ``G`` to its position in the returned vector.
    alpha, w_one_hop, w_two_hop : float, optional
        Unnormalised weights as described above (defaults 0.5, 0.25, 0.05).

    Returns
    -------
    numpy.ndarray
        Vector of length ``len(G.nodes)`` summing to one.
    """
    values = np.zeros(len(G.nodes))

    # Sets give O(1) membership tests. Excluding ``node`` stops a self-loop
    # from overwriting the centre's weight with the one-hop weight.
    one_hop = set(G.neighbors(node)) - {node}

    two_hop = set()
    for neighbor in one_hop:
        two_hop.update(G.neighbors(neighbor))
    two_hop -= one_hop | {node}

    values[node_to_index[node]] = alpha
    for neighbor in one_hop:
        values[node_to_index[neighbor]] = (1 - alpha) * w_one_hop
    for neighbor in two_hop:
        values[node_to_index[neighbor]] = (1 - alpha) * w_two_hop

    return values / values.sum()


def run_experiment(
    graphs, curvature_fn, prob_fn, k, node_level=False, google_matrix=False
):
    """Run experiment on all graphs for a given curvature/probability setup.

    For every graph, ``curvature_fn`` is evaluated and stored as a
    ``"curvature"`` attribute: on nodes if ``node_level`` is set, on edges
    otherwise. All unordered pairs of graphs are then compared in two ways:

    1. *raw*: 1D Wasserstein distance between the curvature values;
    2. *TDA*: sum over dimensions of the Wasserstein distances between the
       persistence diagrams of the curvature-filtered clique complexes.

    A pair counts as distinguished when its distance exceeds ``1e-8``.

    Parameters
    ----------
    graphs : list of networkx graphs
        Graphs to compare. They are modified in place: curvature attributes
        are written to their nodes and edges.
    curvature_fn : callable
        Called as ``curvature_fn(graph, prob_fn=prob_fn)`` if ``prob_fn`` is
        given, else ``curvature_fn(graph, google_matrix=True)`` if
        ``google_matrix`` is set, else ``curvature_fn(graph)``. Its result is
        zipped with ``graph.nodes()`` (node level) or ``graph.edges()``.
    prob_fn : callable or None
        Probability measure forwarded to ``curvature_fn``.
    k : int
        Maximum clique-expansion dimension for persistent homology.
    node_level : bool, optional
        Treat the output of ``curvature_fn`` as per-node values.
    google_matrix : bool, optional
        Forward ``google_matrix=True`` to ``curvature_fn`` (only when
        ``prob_fn`` is ``None``).

    Returns
    -------
    tuple of float
        ``(error_rate_raw, error_rate_tda)``: the fraction of graph pairs
        that were *not* distinguished by the respective comparison.

    Raises
    ------
    ValueError
        If fewer than two graphs are given (no pairs to compare). This is
        checked before any curvature is computed.
    """
    n_graphs = len(graphs)
    if n_graphs < 2:
        raise ValueError(
            f"Need at least two graphs to compare, got {n_graphs}."
        )
    all_pairs = n_graphs * (n_graphs - 1) // 2
    log = _experiment_logger()

    for graph in graphs:
        _assign_curvature(
            graph, curvature_fn, prob_fn, node_level, google_matrix
        )

    # Raw comparison: extract each graph's curvature values once, then
    # compare their distributions pairwise.
    access_fn = (
        nx.get_node_attributes if node_level else nx.get_edge_attributes
    )
    curvatures = [
        list(access_fn(graph, "curvature").values()) for graph in graphs
    ]
    n_pairs = _count_distinguished(curvatures, wasserstein_distance)

    log.add()
    log.info(f"Distinguishing {n_pairs}/{all_pairs} pairs (raw)")
    log.sub()

    error_rate_raw = 1.0 - n_pairs / all_pairs

    # TDA comparison: make sure both nodes and edges carry a "curvature"
    # value, then compute persistence diagrams of the clique complex.
    persistence_diagrams = []
    for graph in graphs:
        if node_level:
            propagate_node_attribute_to_edges(graph, "curvature")
        else:
            # Pools incident edge values into a node value with a constant
            # pooling function, i.e. every node presumably gets -1 (confirm
            # against utils.propagate_edge_attribute_to_nodes).
            propagate_edge_attribute_to_nodes(
                graph, "curvature", pooling_fn=lambda x: -1
            )
        persistence_diagrams.append(calculate_persistent_homology(graph, k=k))

    n_pairs = _count_distinguished(persistence_diagrams, _diagram_distance)

    log.add()
    log.info(f"Distinguishing {n_pairs}/{all_pairs} pairs (TDA)")
    log.sub()

    error_rate_tda = 1.0 - n_pairs / all_pairs
    return error_rate_raw, error_rate_tda


def _experiment_logger():
    """Return the script's global ``log`` adapter, or a fresh one if absent."""
    # ``log`` is only created in the ``__main__`` section; fall back to a new
    # adapter so run_experiment also works when this module is imported.
    logger = globals().get("log")
    if logger is None:
        logger = IndentedLoggerAdapter(logging.getLogger(__name__))
    return logger


def _assign_curvature(graph, curvature_fn, prob_fn, node_level, google_matrix):
    """Compute curvature of ``graph`` and store it as ``"curvature"``."""
    if prob_fn is not None:
        curvature = curvature_fn(graph, prob_fn=prob_fn)
    elif google_matrix:
        curvature = curvature_fn(graph, google_matrix=True)
    else:
        curvature = curvature_fn(graph)

    if node_level:
        nx.set_node_attributes(
            graph, dict(zip(graph.nodes(), curvature)), "curvature"
        )
    else:
        # Normal case for proper (edge-based) curvature measurements.
        nx.set_edge_attributes(
            graph, dict(zip(graph.edges(), curvature)), "curvature"
        )


def _count_distinguished(items, distance_fn, tol=1e-8):
    """Count unordered pairs of ``items`` whose distance exceeds ``tol``."""
    return sum(
        1
        for first, second in itertools.combinations(items, 2)
        if distance_fn(first, second) > tol
    )


def _diagram_distance(diagrams_a, diagrams_b):
    """Sum per-dimension Wasserstein distances between two diagram lists."""
    distance = 0.0
    for D1, D2 in zip(diagrams_a, diagrams_b):
        # reshape(-1, 2) keeps empty diagrams two-dimensional (shape (0, 2));
        # np.asarray([]) alone would have shape (0,).
        distance += gudhi.wasserstein.wasserstein_distance(
            np.asarray(D1, dtype=float).reshape(-1, 2),
            np.asarray(D2, dtype=float).reshape(-1, 2),
        )
    return distance


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("FILE", type=str, help="Input file (in `.g6` format)")

    parser.add_argument(
        "-o",
        "--output",
        type=str,
        help="If set, store output in specified file.",
    )

    parser.add_argument(
        "-k",
        type=int,
        default=2,
        help="Specifies maximum expansion dimension for graphs.",
    )

    args = parser.parse_args()
    graphs = nx.read_graph6(args.FILE)

    # `read_graph6` returns a bare graph instead of a list when the file
    # contains a single graph; normalise so experiments always get a list.
    if isinstance(graphs, nx.Graph):
        graphs = [graphs]

    prob_fns = [
        ("default", None),
        ("random_walk", prob_rw),
        ("two_hop", prob_two_hop),
    ]

    # Collects rows for the experimental table that is written at the end.
    # This makes it possible to "fire and forget" jobs on the cluster.
    rows = []

    logging.basicConfig(format="%(message)s", level=logging.INFO)
    # Module-level on purpose: run_experiment looks this adapter up.
    log = IndentedLoggerAdapter(logging.getLogger(__name__))

    log.info(f"Running experiment with {len(graphs)} graphs")

    log.info("Laplacian spectrum")
    log.add()

    e1, e2 = run_experiment(
        graphs, laplacian_eigenvalues, None, args.k, node_level=True
    )

    rows.append(
        {
            "name": "Laplacian spectrum",
            "raw": [e1],
            "tda": [e2],
        },
    )

    log.sub()

    log.info("Forman--Ricci curvature")
    log.add()

    e1, e2 = run_experiment(graphs, forman_curvature, None, args.k)

    rows.append(
        {
            "name": "Forman--Ricci curvature",
            "raw": [e1],
            "tda": [e2],
        },
    )

    log.sub()

    log.info("Ollivier--Ricci curvature")
    log.add()

    for name, prob_fn in prob_fns:
        log.add()
        log.info(f"Probability measure: {name}")

        e1, e2 = run_experiment(
            graphs, ollivier_ricci_curvature, prob_fn, args.k
        )

        rows.append(
            {
                "name": "Ollivier--Ricci curvature",
                "prob": name,
                "raw": [e1],
                "tda": [e2],
            },
        )

        log.sub()

    log.sub()

    log.info("Resistance curvature")
    log.add()

    e1, e2 = run_experiment(graphs, resistance_curvature, None, args.k)

    rows.append({"name": "Resistance curvature", "raw": [e1], "tda": [e2]})

    log.sub()

    # Previously this run had no header/indent but still called log.sub(),
    # leaving the logger's indentation unbalanced.
    log.info("Ollivier--Ricci curvature (Google matrix)")
    log.add()

    e1, e2 = run_experiment(
        graphs, ollivier_ricci_curvature, None, args.k, google_matrix=True
    )

    rows.append(
        {
            "name": "Ollivier--Ricci curvature",
            "prob": "google_matrix",
            "raw": [e1],
            "tda": [e2],
        },
    )

    log.sub()

    frames = [pd.DataFrame.from_dict(row) for row in rows]
    df = pd.concat(frames, ignore_index=True)
    print(df)

    if args.output is not None:
        df.to_csv(args.output, index=False)
